# ------------------------------------------------------------------------------
# Plot marginal tests for Shift RE testing bins
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 05_marginal_tests_plot.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(glue) # glue strings
library(ggplot2)
library(ggthemes) # colorblind
library(broom) # tidy

# Shared project modules: plotting aesthetics (`a`) and data helpers (`u`).
a <- modules::use(here::here("lib", "aesthetics.R"))
u <- modules::use(here::here("lib", "util.R"))
# Load the Optima font file shipped under lib/ (presumably registers it for
# use in plot themes -- see lib/aesthetics.R).
a$get_font("Optima", here("lib", "optima.ttf"))

# Scratch directory for this step: the input RDS is read from here and plots
# are written under <temp>/plots.
temp <- here::here("code", "05_natural_experiment", "temp")
# NOTE(review): not referenced anywhere else in this script.
overnight_lab <- ""

# Helper Functions -------------------------------------------------------------
# Add axis/legend labels (and optionally title, subtitle, caption) to a plot.
#
# Intended to be called row-wise via purrr::pmap() over the plot config
# table, so every column arrives as a named argument; columns this function
# does not use (e.g. the plot list-column itself) are absorbed by `...`.
#
# Args:
#   gg: ggplot object to label.
#   outcome: outcome column name; known names map to a readable y label,
#     anything else is used verbatim.
#   x_var: x-axis column name; known names map to a readable x label,
#     anything else is used verbatim.
#   population: population descriptor interpolated into the subtitle.
#   mean_outcomes: grouped summary data; its `n` column is summed for the
#     caption (assumes such a column exists -- only read when incl_titles).
#   group_var: grouping column name (accepted for pmap compatibility; not
#     used in any label).
#   incl_titles: if TRUE, also add title, subtitle and an "N = ..." caption.
#   ...: ignored extra columns.
#
# Returns: `gg` with the labels added.
add_labels <- function(
  gg, outcome, x_var, population, mean_outcomes, group_var,
  incl_titles = FALSE, ...
) {
  xlabel <- case_when(
    x_var == "tile_stent_or_cabg_010_tested" ~ "Percentile of Predicted Risk",
    x_var == "tile_stent_or_cabg_005_tested" ~ "Percentile of Predicted Risk",
    TRUE ~ x_var
  )
  ylabel <- case_when(
    outcome == "test_010_day" ~ "Test Rate",
    outcome == "macetrop_030_pos" ~ "MACE Rate (30 day)",
    outcome == "stent_or_cabg_010_day" ~ "Yield of Testing",
    TRUE ~ outcome
  )
  legend_label <- "Shift Test Rate Quartile"

  gg <- gg + labs(
    x = xlabel, y = ylabel, color = legend_label, fill = legend_label
  )

  if (incl_titles) {
    plot_label <- glue("{ylabel} by {xlabel}")
    subtitle <- glue("Holdout cohort ({population}), excluding patients with chronic illness history, patients over 80")
    # Any population other than "tested" also includes untested patients.
    if (!isTRUE(population == "tested")) {
      subtitle <- glue("{subtitle}, untested same-day AMI")
    }
    n <- sum(mean_outcomes$n)
    caption <- glue("N = {n}")
    # Use the computed label; a bare `title` here would resolve to the
    # graphics::title() function rather than a string.
    gg <- gg + labs(title = plot_label, subtitle = subtitle, caption = caption)
  }

  gg
}

# Build the output PNG path for one plot config row.
#
# Intended to be called row-wise via purrr::pmap(), so unused columns land in
# `...`. Relies on the script-level `temp` directory and on
# `u$get_xvar_lab()` to shorten the x-variable name for the file name.
#
# Args:
#   outcome: outcome column name.
#   x_var: x-axis column name (passed through u$get_xvar_lab()).
#   group_var: grouping column name.
#   ...: ignored extra columns.
#
# Returns: a single path
#   "<temp>/plots/<outcome>__by__<x label>__for__<group_var>.png".
get_filename <- function(outcome, x_var, group_var, ...) {
  x_var_lab <- u$get_xvar_lab(x_var)
  # Return the path directly rather than via an (invisible) assignment.
  file.path(
    temp, "plots", glue("{outcome}__by__{x_var_lab}__for__{group_var}.png")
  )
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
# NOTE(review): `paths` is not referenced later in this script.
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Analysis cohort: drop excluded rows, bin by predicted-risk tile, then keep
# only the train/test splits. The factor is built BEFORE the split filter, so
# its levels reflect all non-excluded rows; keep this order.
cohort <- file.path(temp, "shift_bins_df.rds") %>%
  readRDS() %>%
  filter(!exclude) %>%
  mutate(yhat_bin = factor(tile_stent_or_cabg_005_tested)) %>%
  filter(split %in% c("train", "test"))

# Encounter id -> shift-bin lookup.
# NOTE(review): `bins_only` is not referenced later in this script.
bins_only <- select(
  cohort,
  ed_enc_id,
  bin_re_shift_12_full_test_010_day_inclyhat_linear
)

# Plot config ------------------------------------------------------------------
message("Configuring plot parameters...")
# crossing() expands every combination of the values below into one row per
# plot; with a single value each, this yields a one-row table.
config <- crossing(
  outcome = "test_010_day",
  x_var = "tile_stent_or_cabg_005_tested",
  group_var = "bin_re_shift_12_full_test_010_day_inclyhat_linear"
)

# Plots ------------------------------------------------------------------------
message("Making plots...")
# Grow the plot table one column at a time. Each pmap() call passes every
# column built so far to the helper as named arguments, so the order of these
# assignments matters.
outcome_plots <- config
outcome_plots$population <- unlist(pmap(outcome_plots, u$get_population))
outcome_plots$mean_outcomes <- pmap(
  outcome_plots, u$get_grouped_mean_outcomes, df = cohort
)
outcome_plots$gg <- pmap(
  outcome_plots, a$tile_plot, palette = a$ordered_palette
)
outcome_plots$plot <- pmap(outcome_plots, add_labels)
outcome_plots$filename <- unlist(pmap(outcome_plots, get_filename))

print(outcome_plots$mean_outcomes[1])

# Save -------------------------------------------------------------------------
message("Saving...")
# ggsave() errors when the target directory is missing (it does not create
# it by default), so make sure every output directory exists first.
walk(
  unique(dirname(outcome_plots$filename)),
  dir.create, recursive = TRUE, showWarnings = FALSE
)
# pwalk(): this loop is run only for its side effect (writing PNGs), so do not
# return/auto-print a list of results at top level.
outcome_plots %>%
  select(plot, filename) %>%
  pwalk(ggsave, width = 10, height = 7, units = "in")

message("Done.")
